基于网格搜索优化的SVM分类预测
本文作者:许林丽,中南财经政法大学统计与数学学院
本文编辑:周一鸣
技术总编:孙一博
Stata and Python 数据分析
爬虫俱乐部Stata基础课程、Stata进阶课程和Python课程可在小鹅通平台查看,欢迎大家多多支持订阅!如需了解详情,可以通过课程链接(https://appbqiqpzi66527.h5.xiaoeknow.com/homepage/10)或课程二维码进行访问哦~网格搜索(GridSearch)是一种调参方法,基本原理是在所有候选的参数选择中,通过循环遍历,尝试每一种可能性,表现最好的参数就是最终的结果,因此也被称为“穷举搜索”和“暴力搜索”。 此外,使用交叉验证可以使得评分更加严谨,因此交叉验证经常与网格搜索一起结合使用,即GridSearchCV,最后可以从列出的超参数中选择最佳参数。可以看出,需要遍历所有可能的参数组合的网格搜索的缺点就是非常耗时!!特别是在处理大数据集和多参数时。
本次小编将带着大家学习如何用支持向量机对数据集进行分类预测,并且使用网格搜索进行调参,使模型更加精确。
sklearn.svm.SVC(C=1.0,kernel='rbf', degree=3, gamma='auto',coef0=0.0,shrinking=True,probability=False,tol=0.001,cache_size=200, class_weight=None,verbose=False,max_iter=-1,decision_function_shape=None,random_state=None)
C:目标函数的惩罚系数,默认值为1.0。C越大,表示在训练样本中准确率越高,但泛化能力低,即对测试数据的分类准确率降低。kernel:核函数类型,默认为‘rbf’。常用的可选参数有linear(线性核函数)、poly(多项式核函数)、rbf(径像核函数/高斯核)和sigmoid(双曲正切核函数)。degree:使用kernel为 ‘poly’时,给定多项式的项数,默认为3。若指定kernel为其他核函数则忽略该参数。gamma:表示当kernel为‘rbf’, ‘poly’或‘sigmoid’时的kernel系数,默认为 ‘auto’,即样本特征数的倒数。coef0:核函数的常数项,只有在 kernel为‘poly’或‘sigmoid’时有效,默认为0.0。shrinking:是否采用启发式,默认为True。probability:是否启用概率估计,默认为False。tol:训练结束要求的精度,默认为0.001。cache_size:指定训练所需要的内存,以MB为单位,默认为200MB。class_weight:给定各个类别的权重,默认为1。verbose:是否详细输出训练过程,默认为False。max_iter:最大迭代次数,默认为-1,表示无穷大迭代次数。decision_function_shape:多分类时选择的方式,有‘ovo’、‘ovr’和None三种,默认为None。random_state:将训练集打乱顺序时使用的伪随机数生成器的种子,默认为None。主要需要调节的参数有:C、kernel、degree、gamma、coef0。sklearn.model_selection.GridSearchCV(estimator, param_grid=None, scoring=None, cv=None, verbose=0)
estimator:参数针对的搜索对象,即所使用的分类器。param_grid:需要最优化的参数的取值,值可以是字典或者列表。scoring:模型评价标准,根据所选模型不同,评价准则不同。默认None,使用estimator的误差估计函数。n_jobs:与并行运行相关,可以提高搜索速度,取值为整数,默认为1,大于1的整数表示运行核数(但不能超过运行主机有的核数),取-1则代表使用主机所有的核数。cv:交叉验证参数,默认None,使用三折交叉验证。verbose:日志冗长度,取值为整数。取值为0,则不输出训练过程;取值为1,则偶尔输出;取值大于1,则对每个子模型都输出。常用评价标准参数说明:grid_search.best_estimator_:查看带有最优超参的搜索器的相关信息。grid_search.best_score_:查看当前最优超参情况下的得分。grid_search.best_params_:输出当前由最优的超参及其取值组成的字典。import numpy as np
import pandas as pd
from sklearn.datasets import load_breast_cancer ##导入数据集
from sklearn import svm
from sklearn.svm import SVC ##导入SVC函数
from sklearn.metrics import classification_report ##导入模型评估函数
from sklearn.model_selection import GridSearchCV ##导入网格搜索函数
from sklearn.model_selection import train_test_split ##导入划分数据集函数
from sklearn.metrics import roc_curve, auc ##用于计算roc和auc
data = load_breast_cancer() ##乳腺癌数据集
X = data.data ##数据特征
Y = data.target ##数据标签
02划分训练集、验证集和测试集验证集可以用于调整模型的超参数和用于对模型的能力进行初步评估,常用来在模型迭代训练时,验证当前模型泛化能力(准确率,召回率等),防止过拟合的现象出现,并决定如何调整超参数。对于小规模样本集,常用的分配比例是 60% 训练集(train)、20% 验证集(val)、20% 测试集(test)。x_train, x_test_val, y_train, y_test_val = train_test_split(X, Y, test_size=0.4, random_state=0)
x_val, x_test, y_val, y_test = train_test_split(x_test_val, y_test_val, test_size=0.5, random_state=0)
03数据标准化处理def zscore_normalize_features(X):
mu = np.mean(X, axis=0)
sigma = np.std(X ,axis=0)
X_norm = (X - mu) / sigma
return (X_norm)
x_train = zscore_normalize_features(x_train)
x_val = zscore_normalize_features(x_val)
x_test = zscore_normalize_features(x_test)
04模型训练机器学习中常使用精确度、 查准率、召回率以及 F1 得分作为评估分类效率的评价指标。在这里,我们使用精确度和F1得分来对模型性能进行评估。
print("————————————调参前————————————")
clf = svm.SVC()
clf.fit(x_train, y_train)
predictions = clf.predict(x_val)
print(classification_report(y_val, predictions))
params = [
{'kernel':['linear'], 'C':[1,10,100]},
{'kernel':['poly'], 'C':[1,10], 'degree':[2,3]},
{'kernel':['rbf'], 'C':[1,10,100],
'gamma':[1, 0.1, 0.01, 0.001]}]
model = GridSearchCV(SVC(), param_grid=params, cv=5)
model.fit(x_train, y_train)
print(model.best_params_)
print("————————————调参后————————————")
predictions_val = model.predict(x_val)
print(classification_report(y_val, predictions_val))
最后,利用测试集对已训练好的模型做最后评估,由结果可以看出模型精度和F1得分较高,说明模型分类效果较好。
print("————————————测试集————————————")
predictions_test = model.predict(x_test)
print(classification_report(y_test, predictions_test))
重磅福利!为了更好地服务各位同学的研究,爬虫俱乐部将在小鹅通平台上持续提供金融研究所需要的各类指标,包括上市公司十大股东、股价崩盘、投资效率、融资约束、企业避税、分析师跟踪、净资产收益率、资产回报率、国际四大审计、托宾Q值、第一大股东持股比例、账面市值比、沪深A股上市公司研究常用控制变量等一系列深加工数据,基于各交易所信息披露的数据利用Stata在实现数据实时更新的同时还将不断上线更多的数据指标。我们以最前沿的数据处理技术、最好的服务质量、最大的诚意望能助力大家的研究工作!相关数据链接,请大家访问:(https://appbqiqpzi66527.h5.xiaoeknow.com/homepage/10)或扫描二维码:
最后,我们为大家揭秘雪球网(https://xueqiu.com/)最新所展示的沪深证券和港股关注人数增长Top10。
对我们的推文累计打赏超过1000元,我们即可给您开具发票,发票类别为“咨询费”。用心做事,不负您的支持!
往期推文推荐Stata18之dtas——The new in data management
What’ new ? 速通Stata 18
【爬虫实战】Python爬取美食菜谱揭秘网络中心人物,你会是其中之一吗?考研之后,文科生需以“do”躬“do”!焕新升级!轻松获取港股、权证的历史交易数据爬虫俱乐部的精彩答疑---cntraveltime【爬虫俱乐部新命令速递】在Stata中与ChatGPT对话用`fs`命令批量获取文件夹和不同文件夹下的excel文件
自然语言处理之实例应用JSON帮手,FeHelper
最新、最热门的命令这里都有!
Python实现微信自动回复告诉python,我想“狂飙”了——线程池与异步协程为爬虫提速高级函数——map()和reduce()Stata绘制条形图的进阶用法
快来看看武汉的房价是不是又双叒叕涨了!关于我们
微信公众号“Stata and Python数据分析”分享实用的Stata、Python等软件的数据处理知识,欢迎转载、打赏。我们是由李春涛教授领导下的研究生及本科生组成的大数据处理和分析团队。
武汉字符串数据科技有限公司一直为广大用户提供数据采集和分析的服务工作,如果您有这方面的需求,请发邮件到statatraining@163.com,或者直接联系我们的数据中台总工程司海涛先生,电话:18203668525,wechat: super4ht。海涛先生曾长期在香港大学从事研究工作,现为知名985大学的博士生,爬虫俱乐部网络爬虫技术和正则表达式的课程负责人。
此外,欢迎大家踊跃投稿,介绍一些关于Stata和Python的数据处理和分析技巧。
投稿邮箱:statatraining@163.com投稿要求:1)必须原创,禁止抄袭;2)必须准确,详细,有例子,有截图;注意事项:1)所有投稿都会经过本公众号运营团队成员的审核,审核通过才可录用,一经录用,会在该推文里为作者署名,并有赏金分成。2)邮件请注明投稿,邮件名称为“投稿+推文名称”。3)应广大读者要求,现开通有偿问答服务,如果大家遇到有关数据处理、分析等问题,可以在公众号中提出,只需支付少量赏金,我们会在后期的推文里给予解答。